import jsonlines
import re
import sys
import os
import pdb


def evaluate(p_file: str, s_file: str):
    prediction_data = [d for d in jsonlines.open(p_file)]
    source_data = [d for d in jsonlines.open(s_file)][:len(prediction_data)]

    pattern = pattern_option if ord("A") <= ord(source_data[0]['target']) <= ord("Z") else pattern_number

    count = 0
    for pd, sd in zip(prediction_data, source_data):
        # pdb.set_trace()
        predict_string = pd["R"].split("<|im_end|>")[2]

        try:
            prediction = re.findall(pattern, predict_string)[0]
        except:
            prediction = "None"
        if prediction[1] == sd["target"]:
            count += 1
    return count / len(prediction_data), len(prediction_data), count


if __name__ == "__main__":
    # Usage: python <this script> <benchmark>
    # Scores every ./output/ood/<benchmark>/*.jsonl against ./data/ood/<benchmark>/<same name>.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <benchmark>")
    benchmark = sys.argv[1]

    # Sorted so per-task output order is deterministic (os.listdir order is arbitrary).
    tasks = sorted(t for t in os.listdir(f"./output/ood/{benchmark}/") if t.endswith(".jsonl"))

    # Module-level answer patterns, kept as globals for any code that reads them by name.
    # Raw strings: "\(" in a plain literal is an invalid escape sequence.
    pattern_option = re.compile(r"\([A-Z]\)")
    pattern_number = re.compile(r"[0-9]+")
    pattern_bool = re.compile(r"True|False")

    all_number, all_hit = 0, 0

    for task in tasks:
        predict_file = f"./output/ood/{benchmark}/{task}"
        source_file = f"./data/ood/{benchmark}/{task}"

        task_acc, number, hit = evaluate(predict_file, source_file)
        all_number += number
        all_hit += hit

        print(f"{task}: {task_acc}")

    # Guard against ZeroDivisionError when there are no tasks or no predictions.
    if all_number:
        print(f"Overall: {all_hit / all_number}")
    else:
        print("Overall: n/a (no predictions found)")




















